
# load packages
library(tidyverse)
library(brms)

# Seed the RNG so any random draws made later in this script are reproducible.
set.seed(6157518)

# Build the analysis data: respondent rows with "Issue <id>" labels, joined to
# the issue-level metadata (which supplies `awareness`), plus two rescaled
# predictors used by every model:
#   awareness_std       - arm::rescale() of awareness
#   treat_indicator_std - treat_indicator shifted down by 0.5
# The result is previewed with glimpse() and cached to output/rescaled-data.rds.
df <- local({
  respondents <- read_rds("output/respondent-data.rds")
  respondents <- mutate(respondents, issue_id = str_c("Issue ", issue_id))
  issue_meta <- read_csv("output/issue-meta-w-awareness.csv")
  joined <- left_join(respondents, issue_meta)
  rescaled <- mutate(
    joined,
    awareness_std = arm::rescale(awareness),
    treat_indicator_std = treat_indicator - 0.5
  )
  glimpse(rescaled)
  write_rds(rescaled, "output/rescaled-data.rds")
  rescaled
})

# establish a fit priority for convenience
# One row per random-effects structure that is actually fitted below
# (policy level 0-2, issue and category levels 0-3); `model_code` is
# "policy-issue-category". `initial_looic` holds a LOOIC value per structure,
# presumably from an earlier round of fits (NOTE(review): confirm its source).
# Structures are ranked by it (rank 1 = lowest LOOIC) and `fits` is later
# ordered by this rank, so the most promising structures are fitted first;
# the raw LOOIC column is dropped after ranking.
fit_priority <- tibble::tribble(
  ~model_code, ~policy, ~issue, ~category,           ~initial_looic,
  "2-0-2",      2L,     0L,        2L, 69990.4664174618,
  "2-2-0",      2L,     2L,        0L, 69990.5586309033,
  "1-1-2",      1L,     1L,        2L, 69990.5602425716,
  "2-0-0",      2L,     0L,        0L, 69990.8590094858,
  "1-0-2",      1L,     0L,        2L, 69990.9852025904,
  "2-0-3",      2L,     0L,        3L, 69991.0285246165,
  "2-0-1",      2L,     0L,        1L, 69991.0595521636,
  "2-1-3",      2L,     1L,        3L,  69991.073125958,
  "2-1-1",      2L,     1L,        1L, 69991.1336156245,
  "1-1-3",      1L,     1L,        3L, 69991.3659886127,
  "1-0-3",      1L,     0L,        3L, 69991.3779085103,
  "2-1-0",      2L,     1L,        0L, 69991.8060047729,
  "2-2-3",      2L,     2L,        3L, 69991.9002036819,
  "1-2-2",      1L,     2L,        2L, 69992.1218610964,
  "2-3-1",      2L,     3L,        1L, 69992.1793531722,
  "2-3-2",      2L,     3L,        2L, 69992.2129317499,
  "2-2-2",      2L,     2L,        2L, 69992.3142816402,
  "1-1-0",      1L,     1L,        0L, 69992.4141345042,
  "1-2-0",      1L,     2L,        0L, 69992.4578237458,
  "1-3-2",      1L,     3L,        2L, 69992.4856365493,
  "1-2-1",      1L,     2L,        1L, 69992.7342614657,
  "1-3-1",      1L,     3L,        1L, 69992.7478877952,
  "1-0-0",      1L,     0L,        0L, 69992.8282725104,
  "2-3-0",      2L,     3L,        0L, 69992.8518504427,
  "1-2-3",      1L,     2L,        3L, 69992.8788278912,
  "2-1-2",      2L,     1L,        2L, 69992.9660239896,
  "2-2-1",      2L,     2L,        1L, 69992.9679111324,
  "1-1-1",      1L,     1L,        1L, 69993.3177440469,
  "2-3-3",      2L,     3L,        3L, 69993.3898992982,
  "1-3-0",      1L,     3L,        0L, 69993.5081857838,
  "1-0-1",      1L,     0L,        1L, 69994.4414291853,
  "1-3-3",      1L,     3L,        3L, 69994.5480974336,
  "0-1-2",      0L,     1L,        2L, 70001.1636423487,
  "0-1-3",      0L,     1L,        3L,  70001.212321449,
  "0-3-2",      0L,     3L,        2L, 70001.8270288388,
  "0-3-0",      0L,     3L,        0L, 70002.2157392613,
  "0-2-3",      0L,     2L,        3L, 70002.4136347921,
  "0-2-2",      0L,     2L,        2L, 70002.6985788483,
  "0-3-3",      0L,     3L,        3L, 70002.7553848009,
  "0-3-1",      0L,     3L,        1L, 70003.0646151727,
  "0-1-0",      0L,     1L,        0L, 70003.2049867888,
  "0-2-1",      0L,     2L,        1L, 70003.5975400312,
  "0-1-1",      0L,     1L,        1L, 70003.7720320321,
  "0-2-0",      0L,     2L,        0L, 70003.8421227923,
  "0-0-3",      0L,     0L,        3L, 70033.0674083437,
  "0-0-2",      0L,     0L,        2L, 70033.4040606021,
  "0-0-1",      0L,     0L,        1L, 70037.5060798342,
  "0-0-0",      0L,     0L,        0L,  70039.291621577
) %>% 
  # ascending rank: 1 = lowest initial LOOIC; exact ties (if any) are broken
  # at random, so ranks are always distinct
  mutate(fit_priority = rank(initial_looic, ties.method = "random")) %>%
  select(-initial_looic) %>%
  glimpse()


# Build the brms model formula for one random-effects (RE) structure.
#
# Every model shares the fixed effects
#   agree_measure ~ treat_indicator_std * awareness_std
# and, for each grouping factor (policy, issue, category), adds a varying-
# effects term whose complexity is chosen by an integer level:
#   0 = no term for this group
#   1 = (1 | group)                                        varying intercept
#   2 = (1 + treat_indicator_std | group)                  + treatment slope
#   3 = (1 + treat_indicator_std * awareness_std | group)  + awareness terms
#
# Arguments: policy, issue, category - single values in 0:3.
# Returns: a formula; RE terms are appended with update() in the order
#   policy, issue, category.
# Errors: with an informative message if any level is not one of 0, 1, 2, 3
#   (previously this surfaced as "object 'f_p' not found").
build_formula <- function(policy, issue, category) {

  # Append the RE term for `group` at complexity `level` to formula `f`.
  add_re_term <- function(f, level, group) {
    if (!(length(level) == 1 && level %in% 0:3)) {
      stop("`", group, "` must be a single value in 0:3, not ",
           deparse(level), call. = FALSE)
    }
    level <- as.integer(level)
    if (level == 0L) {
      return(f)
    }
    varying <- switch(level,
                      quote(1),
                      quote(1 + treat_indicator_std),
                      quote(1 + treat_indicator_std * awareness_std))
    # bquote() splices the pieces into `. ~ . + (varying | group)`, the same
    # update template the original hand-written branches used
    update(f, as.formula(bquote(. ~ . + (.(varying) | .(as.name(group))))))
  }

  fixed_effects <- agree_measure ~ treat_indicator_std * awareness_std

  f <- add_re_term(fixed_effects, policy, "policy")
  f <- add_re_term(f, issue, "issue")
  add_re_term(f, category, "category")
}

# apply the formula to build the re structures
# Enumerate every candidate RE structure (policy level 0-2; issue and category
# levels 0-3), build its formula, add empty result slots that the fitting loop
# fills in, attach the fit priority, and order rows by that priority.
# Policy level 3 is dropped before building formulas so those 16 formulas
# are never constructed.
fits <- crossing(policy = 0:3,
                 issue = 0:3,
                 category = 0:3) %>%
  filter(policy < 3) %>%
  mutate(formula = pmap(list(policy, issue, category), build_formula),
         model_code = paste(policy, issue, category, sep = "-"),
         max_rhat = NA_real_,
         looic = NA_real_,
         fit_time = NA_real_,
         loo_time = NA_real_,
         fit = NA,
         loo = NA) %>%
  # explicit keys instead of relying on an implicit natural join
  left_join(fit_priority,
            by = c("model_code", "policy", "issue", "category")) %>%
  # priority first; category/issue/policy only break ties (same order the
  # original two stable arrange() calls produced)
  arrange(fit_priority, category, issue, policy) %>%
  glimpse()

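# # prior distributions: brms default priors reported by get_prior() for the reference structure below (kept for reference)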
# f <- agree_measure ~ treat_indicator_std*awareness_std +
#   (1 + treat_indicator_std | policy) +
#   (1 + treat_indicator_std*awareness | issue) +
#   (1 + treat_indicator_std*awareness | category)
# get_prior(f, data = df)
# # prior     class                              coef    group resp dpar nlpar bound
# # 1                             b
# # 2                             b                     awareness_std
# # 3                             b               treat_indicator_std
# # 4                             b treat_indicator_std:awareness_std
# # 5              lkj(1)       cor
# # 6                           cor                                   category
# # 7                           cor                                      issue
# # 8                           cor                                     policy
# # 9  student_t(3, 5, 3) Intercept
# # 10 student_t(3, 0, 3)        sd
# # 11                           sd                                   category
# # 12                           sd                         awareness category
# # 13                           sd                         Intercept category
# # 14                           sd               treat_indicator_std category
# # 15                           sd     treat_indicator_std:awareness category
# # 16                           sd                                      issue
# # 17                           sd                         awareness    issue
# # 18                           sd                         Intercept    issue
# # 19                           sd               treat_indicator_std    issue
# # 20                           sd     treat_indicator_std:awareness    issue
# # 21                           sd                                     policy
# # 22                           sd                         Intercept   policy
# # 23                           sd               treat_indicator_std   policy
# # 24 student_t(3, 0, 3)     sigma

# Priors shared by all candidate models (the fitting loop keeps only the
# classes that exist in a given structure):
#   b         - population-level slopes:      student_t(3, 0, 1)
#   sd        - group-level standard devs:    student_t(3, 0, 1)
#   Intercept - population-level intercept:   student_t(3, 3.5, 3)
#   cor       - group-level correlations:     lkj(3)
#   sigma     - residual standard deviation:  student_t(3, 0, 5)
mprior <- c(
  prior(student_t(3, 0, 1), class = "b"),
  prior(student_t(3, 0, 1), class = "sd"),
  prior(student_t(3, 3.5, 3), class = "Intercept"),
  prior(lkj(3), class = "cor"),
  prior(student_t(3, 0, 5), class = "sigma")
)

# # evaluate the prior by sampling from it alone (sample_prior = "only")
# prior_sample <- brm(formula = f,
#                data = df,
#                algorithm = "sampling",
#                iter = 3000,
#                chains = 3,
#                cores = 3,
#                seed = 409881,
#                prior = mprior,
#                sample_prior = "only",
#                control = list(adapt_delta = 0.90))
# 
# cors <- gather_draws(prior_sample,
#                      cor_category__Intercept__treat_indicator_std,
#                      cor_policy__Intercept__treat_indicator_std)
# ggplot(cors, aes(x = .value)) +
#   geom_histogram() +
#   facet_wrap(vars(.variable))
# cors %>%
#   group_by(.variable) %>%
#   summarize(q10 = quantile(.value, .10),
#             q90 = quantile(.value, .90))
# # A tibble: 2 x 3
#    .variable                                       q10   q90
#    <chr>                                         <dbl> <dbl>
#   1 cor_category__Intercept__treat_indicator_std -0.446 0.446
#   2 cor_policy__Intercept__treat_indicator_std   -0.505 0.502

# make_stancode(f, data = df, prior = mprior)

# fit parameters
# Sampler settings shared by every brm() call in the fitting loop.
algorithm <- "sampling"  # MCMC sampling rather than a variational algorithm
chains <- 3              # independent chains per model
cores <- chains          # run each chain on its own core
iter <- 8000             # iterations per chain for the top-priority structures
low_iter <- 2000         # iterations per chain for all remaining structures
keep <- 2000             # used to derive the thinning interval in the loop

# Fit every candidate structure in priority order. Structures ranked 1-4 get
# `iter` iterations per chain; the rest get `low_iter`. After each fit the
# model is saved to its own file, LOO is computed, diagnostics and timings are
# recorded in `fits`, and the whole `fits` table is re-saved so progress is
# checkpointed after every model.
for (i in seq_len(nrow(fits))) {
  tmp_name <- fits$model_code[i]
  save_as <- paste0("output/fits/model-", sprintf("%02d", i),
                    "-structure-", tmp_name, ".rds")
  print(tmp_name)

  # long runs only for the top-priority structures; a missing priority falls
  # back to the short run instead of passing iter = NA to brm()
  iter_star <- if (isTRUE(fits$fit_priority[i] <= 4)) iter else low_iter

  # thin relative to THIS model's run length so about `keep` iterations per
  # chain are retained; deriving it from `iter` over-thinned `low_iter` runs
  thin_star <- max(1, iter_star %/% keep)

  # keep only priors whose class exists in this structure (e.g. no `sd`/`cor`
  # for the fixed-effects-only model); brms rejects priors matching no parameter
  prior_classes <- unique(get_prior(fits$formula[[i]], data = df)$class)
  mprior_star <- mprior[mprior$class %in% prior_classes, ]

  fit_start <- Sys.time()
  tmp_fit <- brm(formula = fits$formula[[i]],
                 data = df,
                 algorithm = algorithm,
                 iter = iter_star,
                 chains = chains,
                 cores = cores,
                 seed = 3440956,
                 prior = mprior_star,
                 thin = thin_star,
                 control = list(adapt_delta = 0.95)) %>%
    write_rds(save_as)
  fit_end <- Sys.time()

  loo_start <- Sys.time()
  tmp_loo <- loo(tmp_fit, model_names = tmp_name)
  loo_end <- Sys.time()

  fits$looic[i] <- tmp_loo$estimates["looic", "Estimate"]
  fits$max_rhat[i] <- max(rhat(tmp_fit), na.rm = TRUE)
  # record durations in seconds: a bare Sys.time() difference picks its units
  # automatically (secs/mins/hours) and those units are lost once stored
  fits$fit_time[i] <- as.numeric(difftime(fit_end, fit_start, units = "secs"))
  fits$loo_time[i] <- as.numeric(difftime(loo_end, loo_start, units = "secs"))
  fits$fit[[i]] <- list(tmp_fit)
  fits$loo[[i]] <- list(tmp_loo)

  write_rds(fits, "output/fits/_all-fits.rds")
}




